'''
sdf2osim.py
Convert an SDF model to an OpenSim (.osim) model and add joint limits.

required (both read from the current working directory; kept in
analysis/results/joint tracking/):
neuromechfly_locomotion_optimization.sdf  (source of the joint limits)
osim_example.xml                          (osim template the bodies and joints are added to)

input (-i / --input):
The sdf file to convert

output (-o / --output):
Name of the output osim file
preprocessed images stored in analysis/results/catmaid results


usage:
python sdf2osim.py -i neuromechfly_noLimits_ground.sdf -o osim.osim

'''
from xml.etree import ElementTree as ET
from xml.dom import minidom
from xml.etree.ElementTree import tostring
import numpy as np
import pandas as pd
from pandas import DataFrame
import argparse

# Joints converted to osim CustomJoints with rotational coordinates; every
# other sdf joint becomes a WeldJoint. Currently the Coxa, Trochanter and
# Tibia joints of the left and right front (F) legs.
moving_joints = [
    'joint_' + side + 'F' + segment
    for side in ('L', 'R')
    for segment in ('Coxa', 'Trochanter', 'Tibia')
]

parser = argparse.ArgumentParser(description='Translate sdf to osim')
parser.add_argument('-i', '--input', required=True, help='sdf file to convert')
parser.add_argument('-o', '--output', required=True, help='name of the output osim file')
args = parser.parse_args()


sdf = ET.parse(args.input)
sdfroot = sdf.getroot()
fout = args.output

# load template (read from the current working directory); bodies and joints
# are appended to its BodySet/JointSet <objects> lists
osim = ET.parse("osim_example.xml")
osim_root = osim.getroot()
model = osim_root.find("Model")
bodySet = model.find("BodySet")
jointSet = model.find("JointSet")

bodySet_objects = bodySet.find("objects")
jointSet_objects = jointSet.find("objects")

# The <model> may sit directly under <sdf> or be wrapped in a <world>
# (the joint-limits sdf uses the latter layout).
sdfmodel = sdfroot.find("model")
if sdfmodel is None:
    sdfmodel = sdfroot.find("world/model")
if sdfmodel is None:
    parser.error('no <model> element found in ' + args.input)
links = sdfmodel.findall("link")
joints = sdfmodel.findall("joint")

# Link origins (name, x, y, z) from each <link><pose>, kept to compute joint
# frame offsets further down. Massless links are recorded as well, even though
# no osim Body is created for them. Rows are collected in a list and turned
# into a DataFrame once (concatenating per link is quadratic).
link_pose_rows = []

# get and add link to osim
for link in links:
    name = link.attrib['name']
    # <pose> is "x y z roll pitch yaw"; only the position is used. A missing
    # <pose> is the SDF default (origin).
    x, y, z = (float(v) for v in link.findtext("pose", "0 0 0 0 0 0").split()[:3])
    link_pose_rows.append({"name": name, "x": x, "y": y, "z": z})

    inertial = link.find("inertial")
    mass = inertial.find("mass").text
    if float(mass) == 0.0:  # fake link: no osim Body
        continue
    # NOTE(review): only the position part of the inertial pose is kept; its
    # rotation is dropped, so the inertia tensor is used as written -- confirm
    # the sdf inertial frames are unrotated.
    mass_center = ' '.join(inertial.findtext("pose", "0 0 0 0 0 0").split()[:3])
    inertia_src = inertial.find("inertia")
    # osim inertia order: Ixx Iyy Izz Ixy Ixz Iyz
    inertia = ' '.join(
        inertia_src.find(tag).text for tag in ('ixx', 'iyy', 'izz', 'ixy', 'ixz', 'iyz')
    )
    mesh = link.find("visual").find("geometry").find("mesh")
    mesh_file = mesh.find("uri").text.replace('../', '')
    scale_factors = mesh.find("scale").text

    body_et = ET.Element('Body', {'name': name})
    ag_et = ET.SubElement(body_et, 'attached_geometry')
    mesh_et = ET.SubElement(ag_et, 'Mesh', {'name': name + '_geom'})
    ET.SubElement(mesh_et, 'scale_factors').text = scale_factors
    ET.SubElement(mesh_et, 'mesh_file').text = mesh_file
    ET.SubElement(body_et, 'mass').text = mass
    ET.SubElement(body_et, 'mass_center').text = mass_center
    ET.SubElement(body_et, 'inertia').text = inertia
    bodySet_objects.append(body_et)

link_poses = DataFrame(link_pose_rows, columns=['name', 'x', 'y', 'z'])

# Group the sdf joints by osim joint name: the '_yaw'/'_roll' sub-joints of one
# anatomical joint are merged. joint_dict[name_osim] is a list with one
# [parent, child, xyz, lower, upper, joint_type] entry per sdf joint.
joint_dict = {}
for joint in joints:
    name = joint.attrib['name']
    if 'joint' not in name:
        continue  # skip non body-block joints
    joint_type = joint.attrib['type']
    name_osim = name.replace('_yaw', '').replace('_roll', '')
    parent = joint.find('parent').text
    child = joint.find('child').text
    axis = joint.find('axis')
    # The axis sign is dropped; osim axes are written as positive unit vectors.
    xyz = axis.find('xyz').text.replace('-', '')
    # Limits stay as text because they are written straight into <range>.
    # A joint without <limit> (or with a partial one) gets '0' for the
    # missing bound; the default must be a string, not the int 0.
    limit = axis.find('limit')
    if limit is None:
        lower, upper = '0', '0'
    else:
        lower = limit.findtext('lower', '0')
        upper = limit.findtext('upper', '0')
    joint_dict.setdefault(name_osim, []).append(
        [parent, child, xyz, lower, upper, joint_type]
    )

def _strip_link_suffix(link_name):
    """Return an sdf link name without its '_<suffix>' part ('LFCoxa_roll' -> 'LFCoxa')."""
    return link_name.split('_')[0]


def _order_sub_joints(joint_name, info):
    """Order the sdf sub-joints merged into one osim joint along the kinematic chain.

    info is a joint_dict entry: a list of [parent, child, xyz, lower, upper, type].
    The chain starts at the first sub-joint whose parent and child belong to
    different body blocks (names differ after _strip_link_suffix); each next
    entry is the first sub-joint whose parent is the previous entry's child.

    Returns (order, parent, child): order lists indices into info; parent and
    child are the stripped link names of the starting sub-joint.

    Raises ValueError if no starting sub-joint exists or the chain is broken,
    instead of reusing the previous joint's values or looping forever.
    """
    for start, entry in enumerate(info):
        parent = _strip_link_suffix(entry[0])
        child = _strip_link_suffix(entry[1])
        if parent != child:
            break
    else:
        raise ValueError('joint %s: no sub-joint connects two different links' % joint_name)

    order = [start]
    while len(order) < len(info):
        previous_child = info[order[-1]][1]
        next_index = next(
            (idx for idx, entry in enumerate(info) if entry[0] == previous_child),
            None,
        )
        if next_index is None:
            raise ValueError(
                'joint %s: no sub-joint has parent link %s' % (joint_name, previous_child)
            )
        order.append(next_index)
    return order, parent, child


def _link_position(link_name):
    """Return the (x, y, z) origin of link_name as recorded in link_poses."""
    rows = link_poses[link_poses['name'] == link_name]
    if rows.empty:
        raise KeyError('link %s not found in the sdf' % link_name)
    row = rows.iloc[0]
    return row.x, row.y, row.z


def _text_element(tag, text, attrib=None):
    """Create an ElementTree element with the given tag, attributes and text."""
    element = ET.Element(tag, attrib or {})
    element.text = text
    return element


def _linear_function():
    """LinearFunction with coefficients '1 0' (slope 1, intercept 0) driving a TransformAxis."""
    function_et = ET.Element('LinearFunction', {'name': 'function'})
    function_et.append(_text_element('coefficients', ' 1 0'))
    return function_et


def _constant_zero():
    """Constant function for TransformAxes that have no coordinate."""
    function_et = ET.Element('Constant', {'name': 'function'})
    # <value/> can not be identified by opensim, so write an explicit 0
    function_et.append(_text_element('value', '0'))
    return function_et


# Rotation axis (sign ignored) -> (coordinate-name suffix, osim axis text).
# Any other axis is treated as x ('_yaw').
_ROTATION_AXES = {
    (0.0, 0.0, 1.0): ('_roll', '0 0 1'),
    (0.0, 1.0, 0.0): ('_pitch', '0 1 0'),
}
_DEFAULT_ROTATION_AXIS = ('_yaw', '1 0 0')
# Order in which unused axes fill the remaining rotation slots.
_FILL_AXES = ('0 0 1', '0 1 0', '1 0 0')
# Axes of translation1..translation3.
_TRANSLATION_AXES = ('0 0 1', '0 1 0', '1 0 0')


def _rotation_axis(xyz):
    """Map an sdf <xyz> text to (coordinate suffix, osim axis text).

    The text is compared numerically, so '0 0 1', '0.0 0.0 1.0' and
    '0 0 -1' all map to the z axis.
    """
    key = tuple(abs(float(v)) for v in xyz.split())
    return _ROTATION_AXES.get(key, _DEFAULT_ROTATION_AXIS)


# Create one osim joint per joint_dict entry: a CustomJoint with rotational
# coordinates for moving_joints, a WeldJoint for everything else.
for name, info in joint_dict.items():
    coordinate_orders, parent, child = _order_sub_joints(name, info)
    is_moving = name in moving_joints
    if is_moving:
        print("custom joint", name)
        joint_et = ET.Element('CustomJoint', {'name': name})
    else:
        print("weld joint", name)
        joint_et = ET.Element('WeldJoint', {'name': name})

    # The parent frame 'pf' is offset by the child link origin minus the parent
    # link origin (positions only; its orientation is written as '0 0 0').
    parent_xyz = _link_position(parent)
    child_xyz = _link_position(child)
    translation = ' '.join(str(c - p) for c, p in zip(child_xyz, parent_xyz))

    pf_pof_et = ET.Element('PhysicalOffsetFrame', {'name': 'pf'})
    pf_pof_et.append(_text_element('translation', translation))
    pf_pof_et.append(_text_element('orientation', '0 0 0'))
    pf_pof_et.append(_text_element('socket_parent', '/bodyset/' + parent))
    cf_pof_et = ET.Element('PhysicalOffsetFrame', {'name': 'cf'})
    cf_pof_et.append(_text_element('socket_parent', '/bodyset/' + child))
    frames_et = ET.Element('frames')
    frames_et.append(pf_pof_et)
    frames_et.append(cf_pof_et)

    joint_et.append(_text_element('socket_parent_frame', 'pf'))
    joint_et.append(_text_element('socket_child_frame', 'cf'))
    joint_et.append(frames_et)

    # Weld joints have no SpatialTransform or coordinates.
    if is_moving:
        coordinates_et = ET.Element('coordinates')
        spatial_transform_et = ET.Element('SpatialTransform')
        used_axes = set()

        # rotation1..N: one coordinate per sdf sub-joint, in chain order
        # (reorder joint rotation axis -- fixes the coxa-femur joint).
        for slot, order in enumerate(coordinate_orders, start=1):
            _, _, xyz, lower, upper, _ = info[order]
            suffix, axis_text = _rotation_axis(xyz)
            used_axes.add(axis_text)
            coordinate_name = name + suffix

            coordinate_et = ET.SubElement(coordinates_et, 'Coordinate', {'name': coordinate_name})
            coordinate_et.append(_text_element('range', str(lower) + ' ' + str(upper)))
            coordinate_et.append(_text_element('default_value', '0'))

            axis_et = ET.SubElement(
                spatial_transform_et, 'TransformAxis', {'name': 'rotation' + str(slot)}
            )
            axis_et.append(_linear_function())
            axis_et.append(_text_element('coordinates', coordinate_name))
            axis_et.append(_text_element('axis', axis_text))

        # Remaining rotation slots get the unused axes with a constant 0.
        for slot in range(len(coordinate_orders) + 1, 4):
            axis_text = next((a for a in _FILL_AXES if a not in used_axes), '1 0 0')
            used_axes.add(axis_text)
            axis_et = ET.SubElement(
                spatial_transform_et, 'TransformAxis', {'name': 'rotation' + str(slot)}
            )
            axis_et.append(_constant_zero())
            axis_et.append(_text_element('coordinates', ''))
            axis_et.append(_text_element('axis', axis_text))

        # translation1..3 are all held at a constant 0.
        for slot, axis_text in enumerate(_TRANSLATION_AXES, start=1):
            axis_et = ET.SubElement(
                spatial_transform_et, 'TransformAxis', {'name': 'translation' + str(slot)}
            )
            axis_et.append(_text_element('coordinates', ''))
            axis_et.append(_text_element('axis', axis_text))
            axis_et.append(_constant_zero())

        joint_et.append(coordinates_et)
        joint_et.append(spatial_transform_et)
    jointSet_objects.append(joint_et)


#20230324 add joint limits
#joint limits from NeuroMechFly\data\design\sdf\neuromechfly_locomotion_optimization.sdf
# The file is read from the current working directory. For each of its joints
# that maps to an osim CustomJoint, the matching Coordinate's <range> is
# overwritten with the sdf <lower> <upper> limits.
sdf_limit = ET.parse("neuromechfly_locomotion_optimization.sdf")
sdfroot_limit = sdf_limit.getroot()
joint_limits = sdfroot_limit.find("world").find("model").findall("joint")
for jl in joint_limits:
    name = jl.attrib['name']
    name_osim = name.replace('_yaw', '').replace('_roll', '')
    # Only body-block joints that were converted to CustomJoints have
    # coordinates to limit; the test must use this sdf joint's own name.
    if 'joint' not in name or name_osim not in moving_joints:
        continue
    limit = jl.find('axis/limit')
    lower = limit.findtext('lower') if limit is not None else None
    upper = limit.findtext('upper') if limit is not None else None
    if lower is None or upper is None:
        continue  # no limits given for this joint
    print(name, lower)
    joint_et = jointSet_objects.find('.//CustomJoint[@name="' + name_osim + '"]')
    if joint_et is None:
        print('no CustomJoint', name_osim, 'in model; limits skipped for', name)
        continue
    # sdf names without a _roll/_yaw suffix address the pitch coordinate
    if '_roll' not in name and '_yaw' not in name:
        name = name + '_pitch'
    name_path = './/Coordinate[@name="' + name + '"]'
    coordinates = joint_et.find('coordinates')
    coordinate = coordinates.find(name_path)
    if coordinate is None:
        # when the joint has only 1 dof its coordinate name is unsure; use the first one
        coordinate = coordinates.find('Coordinate')
    range_value = coordinate.find('range')
    print(name_path, range_value.text)
    range_value.text = str(lower) + ' ' + str(upper)


# 20230310 add color
# color: from NeuroMechFly\simulation\bullet_simulation.py
# NOTE(review): colors are written comma-separated; osim files usually use
# space-separated Vec3 text -- confirm OpenSim parses these as intended.
def _make_appearance(color, opacity):
    """Build an osim <Appearance> holding <opacity> then <color> with the given text."""
    appearance_et = ET.Element('Appearance')
    ET.SubElement(appearance_et, 'opacity').text = opacity
    ET.SubElement(appearance_et, 'color').text = color
    return appearance_et


# One Appearance per body category; the same element is attached to every
# body mesh of that category.
Appearance_wings = _make_appearance('0.91, 0.96, 0.97', '0.7')
Appearance_eyes = _make_appearance('0.67, 0.21, 0.12', '1')
Appearance_legs = _make_appearance('0.67, 0.51, 0.196', '1')
Appearance_body = _make_appearance('0.55, 0.392, 0.118', '1')

_LEG_SEGMENTS = ('Coxa', 'Femur', 'Tibia', 'Tarsus')


def _appearance_for(body_name):
    """Pick the Appearance for a body from keywords in its name; first match wins."""
    if 'Wing' in body_name:
        return Appearance_wings
    if 'Eye' in body_name:
        return Appearance_eyes
    if any(segment in body_name for segment in _LEG_SEGMENTS):
        return Appearance_legs
    return Appearance_body


for body in bodySet_objects.findall('Body'):
    geometry_mesh = body.find('attached_geometry').find('Mesh')
    geometry_mesh.append(_appearance_for(body.attrib['name']))


# 20230310 set default pose joint values, in degrees
# default pose from NeuroMechFly\data\config\pose\pose_default.yaml
# override fixed joint poses from NeuroMechFly\scripts\kinematic_replay\run_kinematic_replay
#
# Keys are sdf joint names: a '_roll'/'_yaw' suffix selects that coordinate of
# the merged osim joint; a name without a suffix selects its pitch coordinate.
joint_poses = {
    # fixed_positions from run_kinematic_replay.py
    'joint_A3': -15,
    'joint_A4': -15,
    'joint_A5': -15,
    'joint_A6': -15,
    'joint_LPedicel': 35,
    'joint_RPedicel': -35,
    'joint_Rostrum': 90,
    'joint_Haustellum': -60,
    'joint_LWing_roll': 90,
    'joint_LWing_yaw': -17,
    'joint_RWing_roll': -90,
    'joint_RWing_yaw': 17,
    'joint_Head': 10,

    # pose_default.yaml values that are not used:
    # 'joint_A3': -0.8069737300274114,
    # 'joint_A4': -0.1650736529743767,
    # 'joint_A5': -0.08908878107360028,
    # 'joint_A6': 0.029519803564770058,
    # 'joint_Head': 0.5723155574085291,
    # 'joint_LAntenna': -0.002453728960324633,
    # 'joint_RAntenna': 0.002653956767393574,
    # 'joint_Rostrum': -0.10363284957436733,
    # 'joint_Haustellum': -0.008943727172967032,
    # 'joint_RWing_roll': 0.003199068466914467,
    # 'joint_RWing_yaw': -0.009758275213023679,
    # 'joint_LWing_roll': -0.0009181363728491571,
    # 'joint_LWing_yaw': 0.007666345451130245,

    # from pose_default.yaml
    'joint_A1A2': 0.21846014367186156,
    'joint_Head_roll': 0.00026173427764161197,
    'joint_Head_yaw': 0.0013135272686562638,
    'joint_LEye': -0.004643254915867528,
    'joint_REye': -0.0004564965851049248,
    'joint_LFCoxa_roll': -0.014459294441199725,
    'joint_LFCoxa_yaw': 0.1506763651856406,
    'joint_LFCoxa': -0.789880258643274,
    'joint_LFTrochanter': -67.57373506986399,  # was 'joint_LFFemur'
    'joint_LFTrochanter_roll': -0.10095926849385913,  # was 'joint_LFFemur_roll'
    'joint_LFTibia': 43.41127909530307,
    'joint_LFTarsus1': -8.149394415284323,
    'joint_LHCoxa': 1.2104426354333568,
    'joint_LHCoxa_yaw': 10.69658678149118,
    'joint_LHCoxa_roll': 137.6147129043548,
    'joint_LHFemur': -89.38329525236054,
    'joint_LHFemur_roll': -8.097307769798867,
    'joint_LHTibia': 65.7965836898687,
    'joint_LHTarsus1': -6.055704595492672,
    'joint_LHaltere_roll': 0.1600892466660107,
    'joint_LHaltere_yaw': 0.5451379001713288,
    'joint_LHaltere': 0.37692474022375355,
    'joint_LMCoxa': 2.0615638403776515,
    'joint_LMCoxa_yaw': 12.534912657408768,
    'joint_LMCoxa_roll': 101.88137481370443,
    'joint_LMFemur': -95.95693707631287,
    'joint_LMFemur_roll': -0.19534080978081847,
    'joint_LMTibia': 101.0596642359161,
    'joint_LMTarsus1': -6.891833221136733,
    'joint_LWing': 0.13107426328818048,
    'joint_RFCoxa_roll': 0.06824252381150342,
    'joint_RFCoxa_yaw': -0.027652522761576103,
    'joint_RFCoxa': -0.07589164341686852,
    'joint_RFTrochanter': -75.10638587459752,  # was 'joint_RFFemur'
    'joint_RFTrochanter_roll': 0.021435915390630007,  # was 'joint_RFFemur_roll'
    'joint_RFTibia': 51.350331169288935,
    'joint_RFTarsus1': -9.86956643659287,
    'joint_RHCoxa': 16.020107700341416,
    'joint_RHCoxa_yaw': -7.917740753826325,
    'joint_RHCoxa_roll': -139.89327262013938,
    'joint_RHFemur': -75.37435088662505,
    'joint_RHFemur_roll': -15.287777298628123,
    'joint_RHTibia': 65.58923263208715,
    'joint_RHTarsus1': -21.98272090882326,
    'joint_RHaltere_roll': -0.15051986173758183,
    'joint_RHaltere_yaw': -0.16634345859583768,
    'joint_RHaltere': 0.14756148784530043,
    'joint_RMCoxa': 0.6602396437628356,
    'joint_RMCoxa_yaw': -6.046050062746582,
    'joint_RMCoxa_roll': -104.6230995031921,
    'joint_RMFemur': -105.21451777392465,
    'joint_RMFemur_roll': 0.11446637542072276,
    'joint_RMTibia': 98.99614155554471,
    'joint_RMTarsus1': -1.9125087120067328,
    'joint_RWing': 0.1297395296771384,
}

# Write joint_poses (degrees) into the default_value (radians) of the matching
# Coordinate of each moving CustomJoint.
for joint_name, joint_angle in joint_poses.items():
    jname = joint_name.replace('_roll', '').replace('_yaw', '')
    if jname not in moving_joints:
        continue  # only CustomJoints have coordinates
    joint_et = jointSet_objects.find('.//CustomJoint[@name="' + jname + '"]')
    if joint_et is None:
        print('no CustomJoint', jname, 'in model; default pose skipped for', joint_name)
        continue
    # names without a _roll/_yaw suffix address the pitch coordinate
    if '_roll' not in joint_name and '_yaw' not in joint_name:
        joint_name = joint_name + '_pitch'
    name_path = './/Coordinate[@name="' + joint_name + '"]'

    coordinates = joint_et.find('coordinates')
    coordinate = coordinates.find(name_path)
    if coordinate is None:
        # when the joint has only 1 dof its coordinate name is unsure; use the first one
        coordinate = coordinates.find('Coordinate')
    default_value = coordinate.find('default_value')
    default_value.text = str(joint_angle / 180 * np.pi)


# Serialize with minidom for indentation. When the template is already
# indented, toprettyxml keeps its whitespace text nodes and emits
# whitespace-only lines; those are dropped.
tree = ET.ElementTree(osim_root)
xmlstr = minidom.parseString(ET.tostring(tree.getroot())).toprettyxml(indent="   ")
xmlstr = '\n'.join(line for line in xmlstr.splitlines() if line.strip()) + '\n'
# The XML declaration written by toprettyxml names no encoding, which means
# UTF-8, so write UTF-8 regardless of the platform default encoding.
with open(fout, "w", encoding="utf-8") as f:
    f.write(xmlstr)
